/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2024. All rights reserved.
 */
#include "geometric_kernel_attn_grad_tiling.h"
#include "register/op_def_registry.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/tiling_api.h"
#include "ge/utils.h"

using namespace ge;
using namespace std;

namespace {
// Input tensor indices (same order as the inputs registered in the op definition).
constexpr uint32_t POS_INPUT_VALUE = 0;
constexpr uint32_t POS_INPUT_SPATIAL_SHAPES = 1;
constexpr uint32_t POS_INPUT_LEVEL_START_INDEX = 2;
constexpr uint32_t POS_INPUT_SAMPLING_LOCATIONS = 3;
constexpr uint32_t POS_INPUT_ATTN_WEIGHTS = 4;
constexpr uint32_t POS_INPUT_GRAD_OUTPUT = 5;

// Output tensor indices.
constexpr uint32_t POS_OUTPUT_GRAD_VALUE = 0;
constexpr uint32_t POS_OUTPUT_GRAD_ATTN_WEIGHTS = 1;

// Attribute index.
constexpr uint32_t POS_ATTR_NUM_POINTS_REAL = 0;

// Dimension indices inside the `value` tensor shape.
constexpr uint32_t VALUE_BATCH_SIZE_DIM = 0;
constexpr uint32_t VALUE_NUM_KEYS_DIM = 1;
constexpr uint32_t VALUE_EMBED_DIMS_DIM = 2;

// Dimension indices inside the `attn_weights` tensor shape.
constexpr uint32_t ATTN_WEIGHTS_NUM_LEVELS_DIM = 0;
constexpr uint32_t ATTN_WEIGHTS_BATCH_SIZE_DIM = 1;
constexpr uint32_t ATTN_WEIGHTS_NUM_QUERIES_DIM = 2;
constexpr uint32_t ATTN_WEIGHTS_NUM_POINTS_DIM = 3;

// Bytes of UB held back from the tiling budget.
constexpr uint64_t UB_RESERVE_BYTES = 10U * 1024U;
// Byte size of one float32 element and of one 32-byte alignment block.
constexpr uint32_t FLOAT32_BYTES = 4;
constexpr uint32_t BLOCK_BYTES = 32;
} // namespace

namespace optiling {
/**
 * Tiling entry for GeometricKernelAttnGrad.
 *
 * Reads the problem sizes from the storage shapes of `value` (VALUE_*_DIM indices)
 * and `attn_weights` (ATTN_WEIGHTS_*_DIM indices), splits queries over the AIV cores,
 * sizes the per-iteration query bundle from the UB budget, configures the matmul
 * tiling and requests a workspace for a temporary numQueries x numKeysAligned float buffer.
 *
 * Tiling key 1 is selected when UB can hold the "other" buffers, an extra
 * numKeys * embedDims floats, and more than two bundled queries; otherwise key 0.
 *
 * @param context tiling context supplied by the framework.
 * @return GRAPH_SUCCESS, or GRAPH_FAILED on missing inputs / platform info, too-short
 *         input ranks, a UB budget too small for even one bundled query, matmul tiling
 *         failure, or a missing tiling / workspace buffer.
 */
static ge::graphStatus TilingFuncForGeometricKernelAttnGrad(gert::TilingContext* context)
{
    GeometricKernelAttnGradTilingData tiling;

    auto valueTensorPtr = context->GetInputTensor(POS_INPUT_VALUE);
    auto attnWeightsTensorPtr = context->GetInputTensor(POS_INPUT_ATTN_WEIGHTS);
    if (valueTensorPtr == nullptr || attnWeightsTensorPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto valueShape = valueTensorPtr->GetStorageShape();
    auto attnWeightShape = attnWeightsTensorPtr->GetStorageShape();
    // Every dimension index read below must exist; reject shorter ranks instead of
    // tiling with meaningless sizes.
    if (valueShape.GetDimNum() <= VALUE_EMBED_DIMS_DIM ||
        attnWeightShape.GetDimNum() <= ATTN_WEIGHTS_NUM_POINTS_DIM) {
        return ge::GRAPH_FAILED;
    }

    auto platformInfoPtr = context->GetPlatformInfo();
    if (platformInfoPtr == nullptr) {
        return ge::GRAPH_FAILED;
    }
    auto ascendPlatformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr);
    auto aicNum = ascendPlatformInfo.GetCoreNumAic();
    auto aivNum = ascendPlatformInfo.GetCoreNumAiv();
    if (aicNum == 0 || aivNum == 0) {
        return ge::GRAPH_FAILED;
    }
    context->SetBlockDim(aicNum);

    uint32_t batchSize = valueShape.GetDim(VALUE_BATCH_SIZE_DIM);
    uint32_t embedDims = valueShape.GetDim(VALUE_EMBED_DIMS_DIM);
    uint32_t numKeys = valueShape.GetDim(VALUE_NUM_KEYS_DIM);
    uint32_t numLevels = attnWeightShape.GetDim(ATTN_WEIGHTS_NUM_LEVELS_DIM);
    uint32_t numQueries = attnWeightShape.GetDim(ATTN_WEIGHTS_NUM_QUERIES_DIM);
    uint32_t numPoints = attnWeightShape.GetDim(ATTN_WEIGHTS_NUM_POINTS_DIM);

    // Counts are rounded up to a whole 32-byte block of float32 elements.
    const uint32_t numItemsPerBlock = BLOCK_BYTES / FLOAT32_BYTES;
    uint32_t numLevelsAligned = AlignUp(numLevels, numItemsPerBlock);
    uint32_t numKeysAligned = AlignUp(numKeys, numItemsPerBlock);
    uint32_t numPointsAligned = AlignUp(numPoints, numItemsPerBlock);

    // numQueries % aivNum, or aivNum when the split is even. Presumably the number of
    // cores that process numQueriesPerLargeCore (= ceil(numQueries / aivNum)) queries;
    // confirm against the kernel.
    uint32_t numLargeCores = numQueries % aivNum;
    if (numLargeCores == 0) {
        numLargeCores = aivNum;
    }

    // UB budget in float32 elements. All UB arithmetic is done in 64 bits so that large
    // shapes cannot wrap around and produce a bogus bundle size.
    uint64_t ubBytesTotal = 0;
    ascendPlatformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubBytesTotal);
    if (ubBytesTotal <= UB_RESERVE_BYTES) {
        return ge::GRAPH_FAILED;
    }
    const uint64_t ubSize = (ubBytesTotal - UB_RESERVE_BYTES) / FLOAT32_BYTES;
    const uint64_t ubSize4Others = static_cast<uint64_t>(numLevelsAligned) * 3 +
                                   static_cast<uint64_t>(numPointsAligned) * embedDims * 2;
    const uint64_t ubSize4OthersOpt = ubSize4Others + static_cast<uint64_t>(numKeys) * embedDims;
    const uint64_t oneQuerySizePerBundle = static_cast<uint64_t>(numKeysAligned) + embedDims +
                                           static_cast<uint64_t>(numPointsAligned) * 3;
    if (oneQuerySizePerBundle == 0) {
        // Degenerate shapes: nothing to divide the bundle budget by.
        return ge::GRAPH_FAILED;
    }

    uint32_t tilingKey = 0;
    uint64_t ubSize4Bundle = 0;
    if (ubSize4OthersOpt + 2 * oneQuerySizePerBundle < ubSize) {
        tilingKey = 1;
        ubSize4Bundle = ubSize - ubSize4OthersOpt;
    } else {
        // Without this check the subtraction below would wrap.
        if (ubSize4Others >= ubSize) {
            return ge::GRAPH_FAILED;
        }
        ubSize4Bundle = ubSize - ubSize4Others;
    }
    // At least one query plus one block of slack must fit; otherwise numQueriesPerBundle
    // would be zero (or, before this check, wrap to a huge value).
    if (ubSize4Bundle < numItemsPerBlock + oneQuerySizePerBundle) {
        return ge::GRAPH_FAILED;
    }
    context->SetTilingKey(tilingKey);

    const uint32_t numQueriesPerBundle =
        static_cast<uint32_t>((ubSize4Bundle - numItemsPerBlock) / oneQuerySizePerBundle);
    uint32_t numQueriesPerLargeCore = (numQueries + aivNum - 1) / aivNum;

    matmul_tiling::MatmulApiTiling mmTiling(ascendPlatformInfo);
    mmTiling.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT);
    mmTiling.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType::DT_FLOAT, true);
    mmTiling.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND_ALIGN, matmul_tiling::DataType::DT_FLOAT);
    mmTiling.SetShape(numQueriesPerLargeCore, numKeys, embedDims);
    mmTiling.SetOrgShape(numQueriesPerLargeCore, numKeys, embedDims);
    mmTiling.SetBias(false);
    mmTiling.SetBufferSpace(-1, -1, -1);
    if (mmTiling.GetTiling(tiling.mmTilingData) == -1) {
        return ge::GRAPH_FAILED;
    }

    tiling.set_batchSize(batchSize);
    tiling.set_embedDims(embedDims);
    tiling.set_numKeys(numKeys);
    tiling.set_numLevels(numLevels);
    tiling.set_numQueries(numQueries);
    tiling.set_numPoints(numPoints);
    tiling.set_coreNum(aivNum);
    tiling.set_numLargeCores(numLargeCores);
    tiling.set_numQueriesPerBundle(numQueriesPerBundle);
    tiling.set_numQueriesPerLargeCore(numQueriesPerLargeCore);

    if (context->GetRawTilingData() == nullptr) {
        return ge::GRAPH_FAILED;
    }
    tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity());
    context->GetRawTilingData()->SetDataSize(tiling.GetDataSize());

    size_t systemWorkspaceSize = ascendPlatformInfo.GetLibApiWorkSpaceSize();
    // Widen before multiplying: numQueries * numKeysAligned can exceed 32 bits.
    size_t tmp4GradAttnWeightsSize = static_cast<size_t>(numQueries) * numKeysAligned * sizeof(float);
    size_t* currentWorkspace = context->GetWorkspaceSizes(1);
    if (currentWorkspace == nullptr) {
        return ge::GRAPH_FAILED;
    }
    currentWorkspace[0] = systemWorkspaceSize + tmp4GradAttnWeightsSize;

    return ge::GRAPH_SUCCESS;
}
} // namespace optiling

namespace ge {
/**
 * Shape inference for GeometricKernelAttnGrad.
 *
 * grad_value      = [batchSize, numKeys, embedDims], all taken from `value`.
 * grad_attn_weights = [numLevels, batchSize, numQueries, numPoints], where numLevels,
 *                   numQueries and numPoints come from `attn_weights` and batchSize
 *                   comes from `value`.
 *
 * @param context shape-inference context.
 * @return GRAPH_FAILED if any input or output shape handle is null, else GRAPH_SUCCESS.
 */
static ge::graphStatus InferShapeForGeometricKernelAttnGrad(gert::InferShapeContext* context)
{
    const gert::Shape* const value = context->GetInputShape(POS_INPUT_VALUE);
    const gert::Shape* const attnWeights = context->GetInputShape(POS_INPUT_ATTN_WEIGHTS);
    if (!value || !attnWeights) {
        return ge::GRAPH_FAILED;
    }

    gert::Shape* const gradValue = context->GetOutputShape(POS_OUTPUT_GRAD_VALUE);
    gert::Shape* const gradAttnWeights = context->GetOutputShape(POS_OUTPUT_GRAD_ATTN_WEIGHTS);
    if (!gradValue || !gradAttnWeights) {
        return ge::GRAPH_FAILED;
    }

    // Gather every input dimension before any output shape is written.
    const int64_t batch = value->GetDim(VALUE_BATCH_SIZE_DIM);
    const int64_t keys = value->GetDim(VALUE_NUM_KEYS_DIM);
    const int64_t dims = value->GetDim(VALUE_EMBED_DIMS_DIM);
    const int64_t levels = attnWeights->GetDim(ATTN_WEIGHTS_NUM_LEVELS_DIM);
    const int64_t queries = attnWeights->GetDim(ATTN_WEIGHTS_NUM_QUERIES_DIM);
    const int64_t points = attnWeights->GetDim(ATTN_WEIGHTS_NUM_POINTS_DIM);

    *gradValue = {batch, keys, dims};
    *gradAttnWeights = {levels, batch, queries, points};
    return GRAPH_SUCCESS;
}

/**
 * Dtype inference: grad_value copies the dtype of `value`, and grad_attn_weights
 * copies the dtype of `attn_weights`.
 *
 * @param context dtype-inference context.
 * @return always GRAPH_SUCCESS (the SetOutputDataType results are not checked).
 */
static ge::graphStatus InferDataTypeForGeometricKernelAttnGrad(gert::InferDataTypeContext* context)
{
    context->SetOutputDataType(POS_OUTPUT_GRAD_VALUE, context->GetInputDataType(POS_INPUT_VALUE));
    context->SetOutputDataType(POS_OUTPUT_GRAD_ATTN_WEIGHTS, context->GetInputDataType(POS_INPUT_ATTN_WEIGHTS));
    return GRAPH_SUCCESS;
}
} // namespace ge

namespace ops {
class GeometricKernelAttnGrad : public OpDef {
public:
    explicit GeometricKernelAttnGrad(const char* name) : OpDef(name)
    {
        this->Input("value")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("spatial_shapes")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("level_start_index")
            .ParamType(REQUIRED)
            .DataType({ge::DT_INT32})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("sampling_locations")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("attn_weights")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Input("grad_output")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND})
            .AutoContiguous();
        this->Output("grad_value")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});
        this->Output("grad_attn_weights")
            .ParamType(REQUIRED)
            .DataType({ge::DT_FLOAT})
            .Format({ge::FORMAT_ND})
            .UnknownShapeFormat({ge::FORMAT_ND});

        this->SetInferShape(ge::InferShapeForGeometricKernelAttnGrad)
            .SetInferDataType(ge::InferDataTypeForGeometricKernelAttnGrad);
        this->AICore().SetTiling(optiling::TilingFuncForGeometricKernelAttnGrad);
        this->AICore().AddConfig("ascend910b");
        this->AICore().AddConfig("ascend910_93");
    }
};

OP_ADD(GeometricKernelAttnGrad);
} // namespace ops
